#!/usr/bin/env python3


# ---------------------------------------------------------------------------
# Run-mode defaults (each can be overridden on the command line)
# ---------------------------------------------------------------------------
DEFAULT_MODE = "finite"             # either "infinite" or "finite"
DEFAULT_REQUIRED_SUCCESSES = 4      # successes required for every scene
DEFAULT_MAX_ATTEMPTS = 10           # per-scene attempt cap, finite mode only

# ---------------------------------------------------------------------------
# Inputs
# ---------------------------------------------------------------------------
src_dataset_path = ["/path/src_dataset"]      # source demo dataset(s); one or more entries
obj_generalize_path = "/path/objects"         # directory scanned for per-object *.yaml configs
aug_generalize_file = "/path/aug_base.yaml"   # augmentation scene list (loaded into domain_cfg)
CONFIG_FILE_PATH = "/path/configs.yaml"       # simulator config; its "sim" section is used
task_spec_path = "/path/task_mimic.yaml"      # task description YAML

# ---------------------------------------------------------------------------
# Outputs
# ---------------------------------------------------------------------------
TEST_OUT_PATH = "/path/demo"                  # directory receiving the generated data_*.hdf5 files
HDF5_OUTPUT_PATH = "/path/gen_data.hdf5"      # output_path handed to DataMimicWrapper.create_from_hdf5
log_fail_dir = "/path/log_fail"               # directory for failure logs
log_success_dir = "/path/log_success"         # directory for success logs

import os
import sys
import yaml
import numpy as np
import h5py
import json
import time
import torch
import argparse
from datetime import datetime
from copy import deepcopy
from omnigibson.envs.data_generate import DataGenerator
from omnigibson.envs.task_spec import MG_TaskSpec
from omnigibson.envs.data_wrapper import DataMimicWrapper
from omnigibson.envs.env_aug_iter import EnvAugmentor
import yaml
from omnigibson.macros import gm
import omnigibson as og
from datetime import datetime
import os
import numpy as np
import cProfile
import pstats
import io
import omnigibson as og
from omnigibson.envs.data_wrapper import DataMimicWrapper
# from omnigibson.objects.articulated_object import ArticulatedObject
# from omnigibson.utils.config_utils import handle_config
from omnigibson.envs.data_generate import DataGenerator
from omnigibson.envs.task_spec import MG_TaskSpec
from omnigibson.envs.env_aug_iter import EnvAugmentor
import cProfile
import pstats
import io
import re
from pathlib import Path
import logging


# Mutable one-slot holder for the name of the HDF5 file most recently chosen
# by env_reset(). Within this file only index 0 is written and read; the
# second entry is an unused initial placeholder.
current_filename = [None, "data.hdf5"]

# Augmentation scene config; assigned by main() from aug_generalize_file.
domain_cfg = None

# Optional simulator toggles, currently disabled:
# gm.USE_GPU_DYNAMICS = False
# gm.ENABLE_FLATCACHE = True
gm.ENABLE_TRANSITION_RULES = False

def group_yaml_files_by_number(directory_path):
    """Return the paths of the ``.yaml`` files directly inside *directory_path*.

    Despite the name, no grouping is performed: this is a flat listing of
    regular files whose name ends in ``.yaml``. Directories are skipped even
    if their name ends in ``.yaml``.

    The names are sorted because ``os.listdir`` returns entries in arbitrary
    order. Sorting makes every run visit the object configs in the same,
    reproducible order.

    Args:
        directory_path: Directory to scan (not recursive).

    Returns:
        A list of joined paths, sorted by file name; empty if none match.
    """
    return [
        os.path.join(directory_path, name)
        for name in sorted(os.listdir(directory_path))
        if name.endswith(".yaml")
        and os.path.isfile(os.path.join(directory_path, name))
    ]
    

def setup_env_augmentor(default_obj_path, default_texture_path="/path/to/default_texture.jpg"):
    """Construct the ``EnvAugmentor`` used to restore and apply scene augmentations.

    Args:
        default_obj_path: Object config YAML, forwarded as ``default_obj_config_path``.
        default_texture_path: Texture image, forwarded as ``default_texture_path``.
            Defaults to the placeholder path this script has always used, so
            existing single-argument calls behave exactly as before.

    Returns:
        The newly constructed ``EnvAugmentor``.
    """
    return EnvAugmentor(
        default_obj_config_path=default_obj_path,
        default_texture_path=default_texture_path,
    )


def env_reset(mimic_env, env_augmentor, index):
    """Reset the environment into augmentation scene *index* and open a fresh output file.

    Sequence:
        1. Reset the env, then call ``reset_for_new_task`` with the current
           position of the ``ball0`` object.
        2. Restore the augmentor's default state and apply
           ``domain_cfg['aug'][index]``.
        3. Reset the first robot, call ``keep_still()``, and step the sim 10 times.
        4. Point the wrapper's HDF5 output at a new timestamped file under
           ``TEST_OUT_PATH``, creating that directory if needed.

    Args:
        mimic_env: The ``DataMimicWrapper`` that owns the environment.
        env_augmentor: The ``EnvAugmentor`` used to restore/apply augmentations.
        index: Index into ``domain_cfg['aug']``. ``domain_cfg`` must already
            have been loaded by ``main``.

    Returns:
        The info dict passed to ``set_h5_file_path``. It always contains an
        ``'augmentation'`` key. If ``restore2defaultstate`` did not return a
        dict, its result is stored under ``'category_object'``.

    Raises:
        RuntimeError: If the scene registry has no object named ``ball0``.
    """
    current_aug = domain_cfg['aug'][index]

    # Use microsecond resolution. Several resets can occur within one second
    # (e.g. quickly failing retries in main), and a second-resolution name
    # would hand set_h5_file_path the same path as the previous attempt.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"data_{timestamp}.hdf5"
    current_filename[0] = filename

    env = mimic_env.env
    env.reset()
    ball = env.scene.object_registry("name", "ball0")
    if ball is None:
        raise RuntimeError("Object 'ball0' not found in the scene registry")
    env.reset_for_new_task(ball.get_position())

    info = env_augmentor.restore2defaultstate(env)
    env_augmentor.iterate_env_aug(env, current_aug)

    # Normalise info to a dict so the augmentation can always be attached.
    if isinstance(info, dict):
        info['augmentation'] = current_aug
    else:
        info = {'augmentation': current_aug, 'category_object': info}
    print("当前物体：", info)

    robot = env.robots[0]
    robot.reset()
    robot.keep_still()
    # Presumably lets the simulation settle after the resets -- TODO confirm 10 is enough.
    for _ in range(10):
        og.sim.step()

    if not os.path.exists(TEST_OUT_PATH):
        os.makedirs(TEST_OUT_PATH, exist_ok=True)
        print(f"已创建输出目录: {TEST_OUT_PATH}")

    # The info (including the augmentation) is handed to the wrapper with the output path.
    mimic_env.set_h5_file_path(
        env=env,
        output_path=os.path.join(TEST_OUT_PATH, filename),
        info=info)

    return info

def main(CONFIG_FILE_PATH, default_obj_path, success_log, failure_log, required_successes, max_attempts=None):
    """Generate demonstrations for one object config across every augmentation scene.

    Loads the task spec (``task_spec_path``), the simulator config
    (*CONFIG_FILE_PATH*, ``sim`` section) and the augmentation list
    (``aug_generalize_file``, stored in the module-level ``domain_cfg``).
    It then builds the ``DataMimicWrapper`` environment.

    Scenes are replayed in rounds until each one either has
    *required_successes* successes or, in finite mode, has used
    *max_attempts* attempts.

    Args:
        CONFIG_FILE_PATH: Simulator YAML path; its ``sim`` section is used.
        default_obj_path: Object YAML whose first ``objects`` entry replaces
            the ``ball0`` entry in the simulator config.
        success_log: Path of the success log (appended to).
        failure_log: Path of the failure log (appended to).
        required_successes: Number of successes required per scene.
        max_attempts: Maximum attempts per scene; ``None`` means retry forever.

    Returns:
        True if every scene reached *required_successes*, otherwise False.
    """
    import random  # reshuffles the source demos after each success

    global domain_cfg

    with open(task_spec_path, "r") as f:
        task_spec = yaml.load(f, Loader=yaml.FullLoader)
    grasp_spec = MG_TaskSpec(task_spec["mimicgen"]["spec"], task_spec["mimicgen"]["init_pose"])

    with open(CONFIG_FILE_PATH, "r") as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)["sim"]

    # Scene augmentations; env_reset() also reads them through the module global.
    with open(aug_generalize_file, "r") as f:
        domain_cfg = yaml.load(f, Loader=yaml.FullLoader)

    robots_cfg = cfg.get("robots", [])
    robot_sensor_config = robots_cfg[0].get("sensor_config", {}) if robots_cfg else {}

    # external_sensor0 and external_sensor1 are not passed to the wrapper.
    external_sensors_config = cfg.get("env", {}).get("external_sensors", [])
    filtered_sensors_config = [
        sensor
        for sensor in external_sensors_config
        if sensor['name'] not in ('external_sensor0', 'external_sensor1')
    ]

    # A "ball0" entry is swapped for the first object of default_obj_path, but
    # only when it sits at one of these indices of cfg["objects"].
    # NOTE(review): the meaning of this index whitelist is not visible here.
    replaceable_indices = {0, 2, 3, 4, 6, 7, 11, 12, 14, 15, 19, 24, 25, 26, 27, 28, 29}
    for i, item in enumerate(cfg["objects"]):
        if i not in replaceable_indices or item["name"] != "ball0":
            continue
        with open(default_obj_path, "r") as f:
            current_obj_config = yaml.load(f, Loader=yaml.FullLoader)["objects"][0]
        cfg["objects"][i] = deepcopy(current_obj_config)

    augmentor = setup_env_augmentor(default_obj_path)
    data_generator = DataGenerator(grasp_spec, src_dataset_path)

    mimic_env = DataMimicWrapper.create_from_hdf5(
        output_path=HDF5_OUTPUT_PATH,
        robot_obs_modalities=["proprio", "rgb", "depth"],
        robot_sensor_config=robot_sensor_config,
        external_sensors_config=filtered_sensors_config,
        n_render_iterations=10,
        only_successes=False,
        config=cfg
    )

    num_scenes = len(domain_cfg['aug'])
    success_counts = [0] * num_scenes
    attempt_counts = [0] * num_scenes
    is_infinite_mode = (max_attempts is None)
    # Upper bound on replay attempts for one scene within one round; the
    # attempt index j is also used to pick the source demo.
    attempts_per_round = 10

    if is_infinite_mode:
        print(f"运行无限模式: 每个场景必须成功 {required_successes} 次")
    else:
        print(f"运行有限模式: 每个场景最多尝试 {max_attempts} 次或成功 {required_successes} 次")

    # Keep going until every scene is finished (enough successes, or attempt cap reached).
    all_done = False
    while not all_done:
        all_done = True  # falsified below if any scene still needs work

        for i in range(num_scenes):
            if success_counts[i] >= required_successes:
                continue
            if not is_infinite_mode and attempt_counts[i] >= max_attempts:
                continue

            all_done = False

            print(f"处理场景 {i+1}/{num_scenes}, "
                  f"当前成功次数: {success_counts[i]}/{required_successes}, "
                  f"尝试次数: {attempt_counts[i]}{f'/{max_attempts}' if not is_infinite_mode else ''}")

            # In finite mode, cap the round at the remaining budget so that
            # attempt_counts[i] can never exceed max_attempts.
            round_limit = attempts_per_round
            if not is_infinite_mode:
                round_limit = min(round_limit, max_attempts - attempt_counts[i])

            flag = False
            for j in range(round_limit):
                attempt_counts[i] += 1
                info = env_reset(mimic_env, augmentor, i)
                data_generator.select_source_demos()
                ans = data_generator.select_source_demo_in_list(j)
                if not ans:
                    print(f"没有可用的源演示数据 {info}")
                    with open(failure_log, "a") as log_file:
                        log_file.write(f"没有可用的源演示数据 {info}\n")
                    break
                if data_generator.generate(mimic_env):
                    flag = True
                    break

            if flag:
                random.shuffle(data_generator.src_demos)
                success_counts[i] += 1
                aug_str = str(domain_cfg['aug'][i]).replace('\n', ' ')
                print(f"场景 {i+1} 成功! 当前成功次数: {success_counts[i]}/{required_successes}")
                print(f"使用的增强配置: {aug_str}")
                with open(success_log, "a") as log_file:
                    log_file.write(f"{default_obj_path}_{i} 重放成功 (第{success_counts[i]}次)\n")
                    log_file.write(f"增强配置: {aug_str}\n")
                    log_file.write(f"输出文件: {current_filename[0]}\n\n")
            else:
                print(f"场景 {i+1} 失败, 尝试次数: {attempt_counts[i]}")
                with open(failure_log, "a") as log_file:
                    log_file.write(f"{default_obj_path}_{i} 未重放成功 (尝试次数: {attempt_counts[i]})\n")

            if not is_infinite_mode and attempt_counts[i] >= max_attempts:
                print(f"场景 {i+1} 已达到最大尝试次数 {max_attempts}")

            # Once this scene is finished, leave the for loop. The enclosing
            # while loop then rescans from the first scene, and finished
            # scenes are skipped by the checks at the top.
            if success_counts[i] >= required_successes or (not is_infinite_mode and attempt_counts[i] >= max_attempts):
                break

    all_success = all(count >= required_successes for count in success_counts)
    print(f"所有场景成功状态: {'成功' if all_success else '失败'}")
    print(f"成功次数: {success_counts}")
    print(f"尝试次数: {attempt_counts}")
    return all_success

if __name__ == "__main__":
    # Command-line options; the defaults come from the module-level constants.
    parser = argparse.ArgumentParser(description='模拟生成器')
    parser.add_argument('--mode', type=str, choices=['infinite', 'finite'], default=DEFAULT_MODE,
                      help=f'运行模式: infinite=无限尝试直到成功, finite=有最大尝试次数 (默认: {DEFAULT_MODE})')
    parser.add_argument('--successes', type=int, default=DEFAULT_REQUIRED_SUCCESSES,
                      help=f'每个场景需要成功的次数 (默认: {DEFAULT_REQUIRED_SUCCESSES})')
    parser.add_argument('--attempts', type=int, default=DEFAULT_MAX_ATTEMPTS,
                      help=f'有限模式下,每个场景的最大尝试次数 (默认: {DEFAULT_MAX_ATTEMPTS})')
    args = parser.parse_args()

    os.makedirs(log_fail_dir, exist_ok=True)
    os.makedirs(log_success_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    success_log_path = os.path.join(log_success_dir, f"success_log_{timestamp}.log")
    failure_log_path = os.path.join(log_fail_dir, f"failure_log_{timestamp}.log")

    # Both logs share the same header apart from the title line.
    attempts_desc = args.attempts if args.mode == 'finite' else '无限'
    settings_line = f"模式: {args.mode}, 成功次数: {args.successes}, 最大尝试次数: {attempts_desc}\n"
    for log_path, title in ((success_log_path, "成功重放记录"), (failure_log_path, "失败重放记录")):
        with open(log_path, "w") as log_file:
            log_file.write(f"{title} - 开始时间: {timestamp}\n")
            log_file.write(settings_line)
            log_file.write("-" * 50 + "\n")

    file_list = group_yaml_files_by_number(obj_generalize_path)

    # None means "no attempt cap" (infinite mode).
    max_attempts = args.attempts if args.mode == 'finite' else None

    # Compare with None, not truthiness: a finite run with --attempts 0 is not unlimited.
    print(f"运行模式: {args.mode}, 成功次数: {args.successes}, 最大尝试次数: {max_attempts if max_attempts is not None else '无限'}")

    # Process each object config in turn and clear the simulator between them.
    all_files_success = True
    for default_obj_path in file_list:
        file_success = main(CONFIG_FILE_PATH, default_obj_path, success_log_path, failure_log_path,
                          args.successes, max_attempts)
        all_files_success = all_files_success and file_success
        og.clear()

    sys.exit(0 if all_files_success else 1)  # exit code 0 on full success, 1 otherwise